function results = nnsparafac(x,k,sumabs,nonneg,orth,init)
%NNSPARAFAC Sparse, optionally nonnegative, K-component PARAFAC model of a 3-way array.
%   RESULTS = NNSPARAFAC(X,K,SUMABS,NONNEG,ORTH) fits a trilinear model
%   with K components to the I x J x K array X. It alternates updates of
%   the loading matrices A, B and C until the fit improves by less than
%   1e-4, or 100 iterations have been performed. NaN entries of X are
%   excluded from every sum of squares.
%   RESULTS = NNSPARAFAC(...,INIT) starts from INIT = {A0,B0,C0}. When
%   INIT is omitted or empty, uniform random starting values are used.
%
%   SUMABS(m), NONNEG(m) and ORTH(m) (m = 1,2,3) are forwarded to the
%   per-mode update routine SMRupdate3 (defined elsewhere). A negative
%   SUMABS(3) selects SMRupdateBIC for the third mode instead.
%
%   RESULTS fields:
%     fittot     - final fraction of the non-NaN sum of squares explained
%     sse, sse0  - final residual SS and total SS of the non-NaN data
%     difffitend - change in fit at the last iteration
%     fit        - 1 x K fit of each component taken on its own
%     loads      - {A*Da, B, C}
%     numiter    - number of outer iterations performed
%     lambda     - third outputs of the three per-mode updates {la,lb,lc}
%                  (presumably the penalty/threshold levels - confirm)
%     Date       - timestamp string of the run

sumabsa = sumabs(1);
sumabsb = sumabs(2);
sumabsc = sumabs(3);
nonnega = nonneg(1);
nonnegb = nonneg(2);
nonnegc = nonneg(3);

maxit = 100;     % maximum number of outer iterations (was 500 at some point)
convth = 1e-4;   % stop once the fit improves by less than this
smrmaxit = 10;   % iteration limit forwarded to the per-mode updates

[I,J,K] = size(x);

% Unfold x along each mode: I x JK, J x IK and K x IJ.
XA = reshape(x,I,J*K);
XB = reshape(permute(x,[2 1 3]),J,I*K);
XC = reshape(permute(x,[3 1 2]),K,I*J);

% NaN marks missing data; the mask is fixed, so compute it once.
idnan = isnan(XA);
E = XA(~idnan);
sse0 = sum(E(:).^2);

% initialize A, B and C (random when no usable init is supplied)
Da = eye(k);
if nargin<6 || isempty(init)
    A = rand(I,k);
    B = rand(J,k);
    C = rand(K,k);
else
    A = init{1};
    B = init{2};
    C = init{3};
end
% scale the columns of B and C to unit 2-norm
B = B*diag(1./sqrt(diag(B'*B)));
C = C*diag(1./sqrt(diag(C'*C)));

oldfit = 0;
dfit = 1;      % named dfit to avoid shadowing the builtin diff
fit = 0;
c = 0;
while (dfit>convth || fit<0) && c<maxit
    c = c+1;
    % update A; column i of krb(C,B) is kron(C(:,i),B(:,i))
    [A, Da, la] = SMRupdate3(XA,A*Da,krb(C,B),nonnega,sumabsa,orth(1),smrmaxit);

    % column i of A scaled by Da(i,i), i.e. A(:,i)*Da(i,i)
    ADa = bsxfun(@times, A, diag(Da).');

    % update B
    [B, Db, lb] = SMRupdate3(XB,B,krb(C,ADa),nonnegb,sumabsb,orth(2),smrmaxit);

    % update C (a negative sumabsc selects the BIC-based update)
    AB = krb(B,ADa);
    if sumabsc<0
        [C, Dc, lc] = SMRupdateBIC(XC,C,AB,nonnegc,sumabsc,orth(3),smrmaxit);
    else
        [C, Dc, lc] = SMRupdate3(XC,C,AB,nonnegc,sumabsc,orth(3),smrmaxit);
    end
    % NOTE(review): the second outputs Db/Dc of the B and C updates are
    % never used. If SMRupdate3 returns normalized loadings plus a separate
    % scale, that scale is dropped from the model evaluated below - confirm
    % this is intended.

    % fit of the current model on the non-missing entries
    E = XA-A*Da*krb(C,B)';
    E = E(~idnan);
    sse = sum(E(:).^2);
    fit = 1 - sse/sse0;
    dfit = fit - oldfit;
    oldfit = fit;
end
results.fittot = fit;
results.sse = sse;
results.sse0 = sse0;
results.difffitend = dfit;

% fit explained by each component on its own
compfit = zeros(1,k);
for i=1:k
    E = XA-A(:,i)*Da(i,i)*krb(C(:,i),B(:,i))';
    E = E(~idnan);
    compfit(i) = 1 - sum(E(:).^2)/sse0;
end

results.loads{1} = A*Da;
results.loads{2} = B;
results.loads{3} = C;
results.numiter = c;
results.fit = compfit;
results.lambda{1} = la;
results.lambda{2} = lb;
results.lambda{3} = lc;
results.Date = datestr(now);
%%%%%%% internal functions %%%%%%%%%%%%%%%

% function [A Da] = SMRupdate(X,B,C,nonneg,sumabsa)
% % X (I,JK) is the unfolded matrix of x (I,J,K);
% % B (J,k) is loads of second mode
% % C (K,k) is loads of third mode
% % nonneg is {0,1} for nonnegativity
% % sumabsa is the maximum L1 norm of the 2-norm-normalized columns of A = [a1,..,ak];
% % 1 < sumabs < sqrt(I);
%
% k = size(B,2);
%
% % calculate least squares fit
% for i=1:k;
%     BC(:,i) = kron(C(:,i),B(:,i));
% end
% A = X*BC*pinv(BC'*BC);
%
% for i=1:k;
%     % find lambda fulfilling sum(abs(a))<=sumabsa
%     l = searchL(A(:,i),sumabsa,nonneg);
%     A(:,i) = softth(A(:,i),l,nonneg);
% end
% % normalize
% Da = diag(sqrt(diag(A'*A)));
% Dainv = diag(1./sqrt(diag(A'*A)));
% A = A*Dainv;

function xst = softth(x,l,nonneg)
%SOFTTH Soft-threshold X by L.
%   XST = SOFTTH(X,L) returns sign(X).*max(0,abs(X)-L): every entry is
%   shrunk towards zero by L, and entries with abs(X) <= L become zero.
%   XST = SOFTTH(X,L,NONNEG) additionally sets negative results to zero
%   when NONNEG equals 1. Omitting NONNEG is the same as NONNEG = 0.
if nargin < 3
    nonneg = 0;
end
shrunk = max(0, abs(x)-l);
xst = sign(x).*shrunk;
if nonneg==1
    isneg = xst < 0;
    xst(isneg) = 0;
end

function l=searchL(t,sumabst,nonneg)
%SEARCHL Bisect for the soft-threshold level that meets an L1 bound.
%   L = SEARCHL(T,SUMABST,NONNEG) searches for a threshold L >= 0 such that
%   the 2-norm-normalized vector SOFTTH(T,L,NONNEG) has an L1 norm of
%   approximately SUMABST. It returns 0 when T is all zeros or already
%   satisfies sum(abs(T/norm(T))) <= SUMABST. Bisection stops when the
%   bracket is narrower than 1e-5, or after 999 steps; in the latter case
%   the bracket midpoint is returned.

tnorm = norm(t,2);
if tnorm==0 || sum(abs(t./tnorm))<=sumabst
    l=0;
    return
end

tmax = max(abs(t));
l1 = 0;
% Keep the upper bracket just below max(abs(t)) so the largest entry
% survives the threshold before any nonneg clamping. Cap the margin at half
% of tmax so the bracket stays positive even for very small vectors.
% A negative threshold would inflate, not shrink, the entries.
l2 = tmax - min(1e-4, tmax/2);

maxit = 1000;
for iter = 1:maxit-1
    % make a new guess
    lnew = (l1+l2)/2;
    stnew = softth(t,lnew,nonneg);
    % An all-zero trial (possible with nonneg=1) is the sparsest outcome.
    % Treat it as satisfying the bound and lower the threshold, instead of
    % comparing a NaN ratio (0/0) that would keep raising the threshold.
    if ~any(stnew(:)) || sum(abs(stnew/norm(stnew,2)))<sumabst
        l2 = lnew;   % bound satisfied: try a smaller threshold
    else
        l1 = lnew;   % bound violated: a larger threshold is needed
    end
    if (l2-l1)<1e-5
        l = lnew;
        return
    end
end
% Not converged within maxit steps: return the bracket midpoint rather
% than leaving the output unassigned (which is a runtime error).
l = (l1+l2)/2;


function AB = krb(A,B)
%KRB Khatri-Rao product
%   AB = KRB(A,B) returns the column-wise Kronecker product of A (I x F)
%   and B (J x F). Column f of the (I*J) x F result is kron(A(:,f),B(:,f)).
%   An error is raised when A and B have different numbers of columns.
[nrowA,ncol] = size(A);
[nrowB,ncolB] = size(B);

if ncol ~= ncolB
    error(' Error in krb.m - The matrices must have the same number of columns')
end

AB = zeros(nrowA*nrowB,ncol);
for col = 1:ncol
    AB(:,col) = kron(A(:,col),B(:,col));
end